import os,sys
import numpy as np
from PIL import ImageGrab

# Coordinate convention for the captured rectangle:
#   x grows to the right,
#   y grows downward.
class image_position():
    """Locate small colour patterns inside a captured screen rectangle.

    The capture is stored as a flat, row-major pixel sequence: the pixel at
    column ``x`` and row ``y`` lives at index ``y * xnum + x``.  Every
    coordinate handled by the search methods is relative to the top-left
    corner of the capture; ``get_abs_position`` adds the rectangle origin
    back.

    Pattern lists use the following layouts:

    * 3 points, vertical:   ``[up, mid, down]``
    * 3 points, horizontal: ``[left, mid, right]``
    * 5 points (a cross):   ``[up, mid, down, left, right]``
    * 6 points (2 columns): ``[up, mid, down]`` of column ``x`` followed by
      ``[up, mid, down]`` of column ``x + 1``
    """

    def __init__(self):
        self.imgposition = {}           # name -> (x, y) relative to imgrect
        self.img = None                 # flat pixel sequence, None until captured
        self.imgrect = (0, 0, 0, 0)     # (xs, ys, xe, ye) of the capture
        # Defined up front so the attributes exist before the first capture.
        self.xnum = 0                   # capture width in pixels
        self.ynum = 0                   # capture height in pixels

    def set_image_data(self, xs, ys, xe, ye):
        """Capture the screen rectangle ``(xs, ys, xe, ye)``.

        Returns:
            ``-3`` when the rectangle is narrower or shorter than 3 pixels
            (a 3x3 neighbourhood is the minimum needed for any pattern),
            otherwise ``0`` after storing the pixel data.
        """
        if xe - xs < 3 or ye - ys < 3:  # need at least one 3x3 grid
            return -3
        img = ImageGrab.grab(bbox=(xs, ys, xe, ye))  # take the screenshot
        self.img = img.getdata()
        self.imgrect = (xs, ys, xe, ye)
        self.xnum = xe - xs
        self.ynum = ye - ys
        return 0

    def get_image_rect(self):
        """Return the ``(xs, ys, xe, ye)`` rectangle of the last capture."""
        return self.imgrect

    def compare_left_right_point(self, x, y, pleft, pright):
        """Return True when the pixels left and right of ``(x, y)`` match.

        Returns False when nothing has been captured or when either
        neighbour would fall outside the capture (so the flat index can
        never wrap into an adjacent row).
        """
        if self.img is None:
            return False
        if not (1 <= x < self.xnum - 1 and 0 <= y < self.ynum):
            return False
        row = y * self.xnum
        if self.img[row + x - 1] != pleft:
            return False
        if self.img[row + x + 1] != pright:
            return False
        return True

    def compare_up_down_point(self, x, y, pup, pdown):
        """Return True when the pixels above and below ``(x, y)`` match.

        Returns False when nothing has been captured or when either
        neighbour would fall outside the capture.
        """
        if self.img is None:
            return False
        if not (0 <= x < self.xnum and 1 <= y < self.ynum - 1):
            return False
        if self.img[(y - 1) * self.xnum + x] != pup:
            return False
        if self.img[(y + 1) * self.xnum + x] != pdown:
            return False
        return True

    def get_x_by_3point(self, rgblist, xpoint, yoffset):
        """Find a vertical ``[up, mid, down]`` match in column ``xpoint``.

        Rows are scanned top to bottom starting at ``yoffset`` (clamped to
        1 so the pixel above always exists) and stopping one row before the
        bottom edge.

        Returns:
            ``(xpoint, y)`` of the first match, or ``None`` when there is no
            capture, ``rgblist`` does not hold exactly 3 entries, ``xpoint``
            is outside the capture, or nothing matches.
        """
        if self.img is None:
            return None
        if len(rgblist) != 3:
            return None
        if not 0 <= xpoint < self.xnum:
            return None
        uppoint = rgblist[0]
        midpoint = rgblist[1]
        downpoint = rgblist[2]
        if yoffset < 1:
            yoffset = 1
        for y in range(yoffset, self.ynum - 1):
            if self.img[y * self.xnum + xpoint] != midpoint:
                continue
            # compare_up_down_point returns a bool, so test its truth value
            # (comparing the result against None never rejected anything).
            if not self.compare_up_down_point(xpoint, y, uppoint, downpoint):
                continue
            return (xpoint, y)
        return None

    def get_y_by_3point(self, rgblist, xoffset, yoffset):
        """Find a horizontal ``[left, mid, right]`` match.

        When ``xoffset`` is truthy only column ``xoffset`` is examined;
        otherwise every column that has both horizontal neighbours is
        examined.  Rows are scanned top to bottom starting at ``yoffset``
        (negative values are clamped to 0).

        Returns:
            ``(x, y)`` of the first match in row-major order, or ``None``
            when there is no capture, ``rgblist`` does not hold exactly 3
            entries, the requested column has no left/right neighbour
            inside the capture, or nothing matches.
        """
        if self.img is None:
            return None
        if len(rgblist) != 3:
            return None
        if xoffset:
            # A column on (or beyond) the edge has no in-row neighbours.
            if not 1 <= xoffset < self.xnum - 1:
                return None
            columns = (xoffset,)
        else:
            columns = range(1, self.xnum - 1)
        leftpoint = rgblist[0]
        midpoint = rgblist[1]
        rightpoint = rgblist[2]
        for y in range(max(yoffset, 0), self.ynum):
            row = y * self.xnum
            for x in columns:
                if self.img[row + x - 1] != leftpoint:
                    continue
                if self.img[row + x] != midpoint:
                    continue
                if self.img[row + x + 1] != rightpoint:
                    continue
                return (x, y)
        return None

    def _first_column_match(self, column_rgb, columns, accept):
        """Return the first vertical match in ``columns`` that ``accept`` approves.

        Columns are visited in the given order and each one is scanned top
        to bottom with ``get_x_by_3point``; after a rejected candidate the
        scan resumes on the next row of the same column.
        """
        for x in columns:
            ystart = 1
            while True:
                point = self.get_x_by_3point(column_rgb, x, ystart)
                if point is None:
                    break
                if accept(point):
                    return point
                ystart = point[1] + 1
        return None

    def get_position_by_5point_byx(self, rgblist):
        """Find a 5-point cross, scanning column by column.

        ``rgblist`` is ``[up, mid, down, left, right]``.

        Returns:
            ``(x, y)`` of the centre of the first cross in column-major
            order, or ``None`` when there is no capture, ``rgblist`` does
            not hold exactly 5 entries, or nothing matches.
        """
        if self.img is None:
            return None
        if len(rgblist) != 5:
            return None
        left = rgblist[3]
        right = rgblist[4]
        # Edge columns are skipped: the centre needs both horizontal neighbours.
        return self._first_column_match(
            rgblist[:3],
            range(1, self.xnum - 1),
            lambda p: self.compare_left_right_point(p[0], p[1], left, right),
        )

    def get_position_by_5point_byy(self, rgblist):
        """Find a 5-point cross, scanning row by row.

        ``rgblist`` is ``[up, mid, down, left, right]``.

        Returns:
            ``(x, y)`` of the centre of the first cross in row-major order,
            or ``None`` when there is no capture, ``rgblist`` does not hold
            exactly 5 entries, or nothing matches.
        """
        if self.img is None:
            return None
        if len(rgblist) != 5:
            return None
        up, mid, down, left, right = (rgblist[i] for i in range(5))
        # Only interior pixels can be the centre of a cross, which also keeps
        # every neighbour index inside the capture.
        for y in range(1, self.ynum - 1):
            row = y * self.xnum
            for x in range(1, self.xnum - 1):
                if self.img[row + x] != mid:
                    continue
                if self.img[row + x - 1] != left:
                    continue
                if self.img[row + x + 1] != right:
                    continue
                if self.compare_up_down_point(x, y, up, down):
                    return (x, y)
        return None

    def get_position_by_6point_byx(self, rgblist):
        """Find two adjacent vertical triples, scanning column by column.

        ``rgblist[:3]`` is ``[up, mid, down]`` for column ``x`` and
        ``rgblist[3:]`` is ``[up, mid, down]`` for column ``x + 1``.

        Returns:
            ``(x, y)`` of the middle pixel of the left column for the first
            match in column-major order, or ``None`` when there is no
            capture, ``rgblist`` does not hold exactly 6 entries, or nothing
            matches.
        """
        if self.img is None:
            return None
        if len(rgblist) != 6:
            return None
        up2 = rgblist[3]
        mid2 = rgblist[4]
        down2 = rgblist[5]

        def right_column_matches(point):
            x, y = point
            # Compare the middle pixel first, then the ones above and below.
            if self.img[y * self.xnum + x + 1] != mid2:
                return False
            return self.compare_up_down_point(x + 1, y, up2, down2)

        # The last column is skipped because column x + 1 must exist.
        return self._first_column_match(
            rgblist[:3], range(0, self.xnum - 1), right_column_matches
        )

    def add_img_position(self, name, position):
        """Remember ``position`` (relative to the capture) under ``name``."""
        self.imgposition[name] = position

    def print_img_position(self):
        """Print every stored name with its absolute screen coordinates."""
        for key, value in self.imgposition.items():
            print(key, "\tabs ", value[0] + self.imgrect[0], value[1] + self.imgrect[1])

    def get_abs_position(self, name):
        """Return the absolute ``(x, y)`` screen position stored under ``name``.

        Raises:
            KeyError: if ``name`` was never registered with
                ``add_img_position``.
        """
        absx = self.imgposition[name][0] + self.imgrect[0]
        absy = self.imgposition[name][1] + self.imgrect[1]
        return (absx, absy)

